// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package logicalop

import (
	"bytes"
	"fmt"

	"github.com/pingcap/tidb/pkg/expression"
	"github.com/pingcap/tidb/pkg/planner/core/base"
	fd "github.com/pingcap/tidb/pkg/planner/funcdep"
	"github.com/pingcap/tidb/pkg/planner/property"
	"github.com/pingcap/tidb/pkg/planner/util/optimizetrace"
	"github.com/pingcap/tidb/pkg/planner/util/optimizetrace/logicaltrace"
	"github.com/pingcap/tidb/pkg/planner/util/utilfuncp"
	"github.com/pingcap/tidb/pkg/types"
	"github.com/pingcap/tidb/pkg/util/dbterror/plannererrors"
	"github.com/pingcap/tidb/pkg/util/plancodec"
	"github.com/pingcap/tipb/go-tipb"
)

// LogicalExpand represents a logical Expand OP serves for data replication requirement.
// LogicalExpand represents a logical Expand OP serves for data replication requirement.
// It replicates each input row once per rollup grouping set, NULL-ing out the
// group-by columns that are absent from the current set and attaching generated
// columns (gid, and gpos when needed) that identify the producing grouping set.
type LogicalExpand struct {
	LogicalSchemaProducer `hash64-equals:"true"`

	// DistinctGroupByCol is the distinct group by columns. (maybe projected below if it's a non-col)
	DistinctGroupByCol  []*expression.Column `hash64-equals:"true"`
	// DistinctGbyColNames holds the field names paired with DistinctGroupByCol.
	DistinctGbyColNames []*types.FieldName
	// DistinctGbyExprs keeps the old gbyExprs for resolving cases like grouping(a+b):
	// the arg a+b should be resolved to the new projected gby col according to ref pos.
	DistinctGbyExprs []expression.Expression `hash64-equals:"true"`

	// Rollup grouping sets and their derived metadata.
	// DistinctSize is the number of logically distinct grouping sets.
	DistinctSize       int                     `hash64-equals:"true"`
	RollupGroupingSets expression.GroupingSets `hash64-equals:"true"`
	// RollupID2GIDS maps a column's unique ID to the set of grouping IDs in
	// which that column participates (used in numeric-set grouping mode).
	RollupID2GIDS      map[int]map[uint64]struct{}
	// RollupGroupingIDs is the allocated grouping ID per grouping-set offset.
	RollupGroupingIDs  []uint64

	// LevelExprs is the level projections generated from grouping sets, making execution more clear.
	// One projection list per grouping set; filled by GenLevelProjections.
	LevelExprs [][]expression.Expression `hash64-equals:"true"`

	// ExtraGroupingColNames is the generated column names. Eg: "grouping_id" and so on.
	ExtraGroupingColNames []string

	// GroupingMode records the grouping id allocation mode (bit-and vs numeric set).
	GroupingMode tipb.GroupingMode

	// The GID and GPos column generated by logical expand if any.
	// GPos is only present when duplicate grouping sets exist.
	GID      *expression.Column `hash64-equals:"true"`
	GIDName  *types.FieldName
	GPos     *expression.Column `hash64-equals:"true"`
	GPosName *types.FieldName
}

// Init initializes LogicalExpand, registering it as a base logical plan node
// with the TypeExpand plan codec type at the given query-block offset.
func (p LogicalExpand) Init(ctx base.PlanContext, offset int) *LogicalExpand {
	p.BaseLogicalPlan = NewBaseLogicalPlan(ctx, plancodec.TypeExpand, &p, offset)
	return &p
}

// *************************** start implementation of logicalPlan interface ***************************

// HashCode inherits BaseLogicalPlan.LogicalPlan.<0th> implementation.

// PredicatePushDown implements base.LogicalPlan.<1st> interface.
func (p *LogicalExpand) PredicatePushDown(predicates []expression.Expression, opt *optimizetrace.LogicalOptimizeOp) (ret []expression.Expression, retPlan base.LogicalPlan) {
	// Grouping-column related predicates can never cross Expand: Expand changes
	// the nullability of grouping columns, so evaluating such conditions below
	// it would be incorrect. Since Expand is adjacent to Aggregation, any filter
	// wanting to pass through falls into one of two cases:
	// 		1. agg-function related filters (these always stay above the aggregate).
	// 		2. group-by item related filters (always tied to grouping-set columns,
	// 		   which can't be pushed down).
	// Consequently, push nothing into the child and report every incoming
	// predicate as remained; the caller will build a Selection above this Expand.
	childRemained, childPlan := p.BaseLogicalPlan.PredicatePushDown(nil, opt)
	ret = append(childRemained, predicates...)
	return ret, childPlan
}

// PruneColumns implement the base.LogicalPlan.<2nd> interface.
// logicExpand is built in the logical plan building phase, where all the column prune is not done yet. So the
// expand projection expressions is meaningless if it built at that time. (we only maintain its schema, while
// the level projection expressions construction is left to the last logical optimize rule)
//
// so when do the rule_column_pruning here, we just prune the schema is enough.
func (p *LogicalExpand) PruneColumns(parentUsedCols []*expression.Column, opt *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, error) {
	// Expand need those extra redundant distinct group by columns projected from underlying projection.
	// distinct GroupByCol must be used by aggregate above, to make sure this, append DistinctGroupByCol again.
	parentUsedCols = append(parentUsedCols, p.DistinctGroupByCol...)
	// used[i] reports whether schema column i appears in parentUsedCols.
	used := expression.GetUsedList(p.SCtx().GetExprCtx().GetEvalCtx(), parentUsedCols, p.Schema())
	prunedColumns := make([]*expression.Column, 0)
	// Iterate from the tail so in-place removal does not shift indices that are
	// still pending inspection; schema columns and output names are removed in
	// lockstep to keep them aligned.
	for i := len(used) - 1; i >= 0; i-- {
		if !used[i] {
			prunedColumns = append(prunedColumns, p.Schema().Columns[i])
			p.Schema().Columns = append(p.Schema().Columns[:i], p.Schema().Columns[i+1:]...)
			p.SetOutputNames(append(p.OutputNames()[:i], p.OutputNames()[i+1:]...))
		}
	}
	logicaltrace.AppendColumnPruneTraceStep(p, prunedColumns, opt)
	// Underlying still need to keep the distinct group by columns and parent used columns.
	var err error
	p.Children()[0], err = p.Children()[0].PruneColumns(parentUsedCols, opt)
	if err != nil {
		return nil, err
	}
	return p, nil
}

// FindBestTask inherits BaseLogicalPlan.LogicalPlan.<3rd> implementation.

// BuildKeyInfo inherits BaseLogicalPlan.LogicalPlan.<4th> implementation.

// PushDownTopN inherits BaseLogicalPlan.LogicalPlan.<5th> implementation.

// DeriveTopN inherits BaseLogicalPlan.LogicalPlan.<6th> implementation.

// PredicateSimplification inherits BaseLogicalPlan.LogicalPlan.<7th> implementation.

// ConstantPropagation inherits BaseLogicalPlan.LogicalPlan.<8th> implementation.

// PullUpConstantPredicates inherits BaseLogicalPlan.LogicalPlan.<9th> implementation.

// RecursiveDeriveStats inherits BaseLogicalPlan.LogicalPlan.<10th> implementation.

// DeriveStats inherits BaseLogicalPlan.LogicalPlan.<11th> implementation.

// ExtractColGroups inherits BaseLogicalPlan.LogicalPlan.<12th> implementation.

// PreparePossibleProperties inherits BaseLogicalPlan.LogicalPlan.<13th> implementation.

// ExhaustPhysicalPlans implements base.LogicalPlan.<14th> interface.
// It delegates to the shared utilfuncp hook to enumerate the physical Expand
// candidates satisfying the required physical property.
func (p *LogicalExpand) ExhaustPhysicalPlans(prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) {
	return utilfuncp.ExhaustPhysicalPlans4LogicalExpand(p, prop)
}

// ExtractCorrelatedCols implements base.LogicalPlan.<15th> interface.
func (p *LogicalExpand) ExtractCorrelatedCols() []*expression.CorrelatedColumn {
	// LevelExprs is only populated by GenLevelProjections near the end of the
	// logical optimizing phase, whereas ExtractCorrelatedCols may be invoked
	// right after building a correlated subquery plan. Distinguish the two:
	// level-projection generation itself never produces correlated columns,
	// so a nil LevelExprs simply yields nil.
	if p.LevelExprs == nil {
		return nil
	}
	result := make([]*expression.CorrelatedColumn, 0, len(p.LevelExprs[0]))
	for _, levelExprs := range p.LevelExprs {
		for _, oneExpr := range levelExprs {
			result = append(result, expression.ExtractCorColumns(oneExpr)...)
		}
	}
	return result
}

// MaxOneRow inherits BaseLogicalPlan.LogicalPlan.<16th> implementation.

// Children inherits BaseLogicalPlan.LogicalPlan.<17th> implementation.

// SetChildren inherits BaseLogicalPlan.LogicalPlan.<18th> implementation.

// SetChild inherits BaseLogicalPlan.LogicalPlan.<19th> implementation.

// RollBackTaskMap inherits BaseLogicalPlan.LogicalPlan.<20th> implementation.

// CanPushToCop inherits BaseLogicalPlan.LogicalPlan.<21st> implementation.

// ExtractFD implements the base.LogicalPlan.<22nd> interface, extracting the FD from bottom up.
// Expand adds no functional dependencies of its own here; it simply forwards
// the FD set accumulated by the embedded schema producer (i.e. the children's FDs).
func (p *LogicalExpand) ExtractFD() *fd.FDSet {
	// basically extract the children's fdSet.
	return p.LogicalSchemaProducer.ExtractFD()
}

// GetBaseLogicalPlan inherits BaseLogicalPlan.LogicalPlan.<23rd> implementation.

// ConvertOuterToInnerJoin inherits BaseLogicalPlan.LogicalPlan.<24th> implementation.

// *************************** end implementation of logicalPlan interface ***************************

// GetUsedCols extracts all of the Columns used by proj.
func (*LogicalExpand) GetUsedCols() (usedCols []*expression.Column) {
	// Expand itself consumes no columns: it only replicates the child's schema
	// according to the defined grouping sets, so the parent's usage passed down
	// is sufficient and an empty (nil) slice is reported here.
	return nil
}

// GenLevelProjections is used to generate level projections after all the necessary logical
// optimization is done such as column pruning.
//
// For every rollup grouping set it builds one projection list over the (already
// pruned) schema: grouping-set columns absent from the current set are projected
// as typed NULLs, all other columns as plain column refs, and the generated
// gid (and gpos, when duplicate grouping sets exist) constants are appended last.
// The result is stored in p.LevelExprs, one entry per grouping set.
func (p *LogicalExpand) GenLevelProjections() {
	// get all the grouping cols.
	groupingSetCols := p.RollupGroupingSets.AllSetsColIDs()
	p.DistinctSize, p.RollupGroupingIDs, p.RollupID2GIDS = p.RollupGroupingSets.DistinctSize()
	// Duplicate grouping sets exist when the raw set count exceeds the distinct
	// count; only then is the extra gpos column present in the schema.
	hasDuplicateGroupingSet := len(p.RollupGroupingSets) != p.DistinctSize
	schemaCols := p.Schema().Columns
	// Without duplicates the last schema col is gid; with duplicates the last
	// two are gid then gpos.
	nonGenCols := schemaCols[:len(schemaCols)-1]
	gidCol := schemaCols[len(schemaCols)-1]
	if hasDuplicateGroupingSet {
		// last two schema col is about gid and gpos.
		nonGenCols = schemaCols[:len(schemaCols)-2]
		gidCol = schemaCols[len(schemaCols)-2]
	}

	// for every rollup grouping set, gen its level projection.
	for offset, curGroupingSet := range p.RollupGroupingSets {
		levelProj := make([]expression.Expression, 0, p.Schema().Len())
		for _, oneCol := range nonGenCols {
			// if this col is in the grouping-set-cols and this col is not needed by current grouping-set, just set it as null value with specified fieldType.
			if groupingSetCols.Has(int(oneCol.UniqueID)) {
				if curGroupingSet.AllColIDs().Has(int(oneCol.UniqueID)) {
					// needed col in current grouping set: project it as col-ref.
					levelProj = append(levelProj, oneCol)
				} else {
					// un-needed col in current grouping set: project it as null value.
					nullValue := expression.NewNullWithFieldType(oneCol.RetType.Clone())
					levelProj = append(levelProj, nullValue)
				}
			} else {
				// other un-related cols: project it as col-ref.
				levelProj = append(levelProj, oneCol)
			}
		}
		// generate the grouping_id projection expr, project it as uint64.
		// Bit-and mode derives the gid from the set's column bitmask; numeric-set
		// mode uses the pre-allocated incremental id for this set offset.
		gid := p.GenerateGroupingIDModeBitAnd(curGroupingSet)
		if p.GroupingMode == tipb.GroupingMode_ModeNumericSet {
			gid = p.GenerateGroupingIDIncrementModeNumericSet(offset)
		}
		gidValue := expression.NewUInt64ConstWithFieldType(gid, gidCol.RetType.Clone())
		levelProj = append(levelProj, gidValue)

		// generate the grouping_pos projection expr, project it as uint64 if any.
		if hasDuplicateGroupingSet {
			gposCol := schemaCols[len(schemaCols)-1]
			// gpos value can equal the grouping set index offset.
			gpos := expression.NewUInt64ConstWithFieldType(uint64(offset), gposCol.RetType.Clone())
			// gen-col: project it as uint64.
			levelProj = append(levelProj, gpos)
		}
		p.LevelExprs = append(p.LevelExprs, levelProj)
	}
}

// GenerateGroupingMarks generate the groupingMark for the source column specified in grouping function.
func (p *LogicalExpand) GenerateGroupingMarks(sourceCols []*expression.Column) []map[uint64]struct{} {
	// A grouping function may take several args like grouping(a,b), so sourceCols
	// can hold more than one column.
	// reference: https://dev.mysql.com/blog-archive/mysql-8-0-grouping-function/
	// Say GROUPING(b,a) under "group by a,b with rollup" (note b,a is reversed
	// from the gby item order). If GROUPING(b,a) returns 3, the NULLs in both
	// "b" and "a" for that row were produced by ROLLUP; a result of 2 means only
	// the NULL in "a" came from ROLLUP.
	//
	// Formula: GROUPING(x,y,z) = GROUPING(x) << 2 + GROUPING(y) << 1 + GROUPING(z)
	//
	// Hence for a multi-arg GROUPING we return one mark per simple column; at
	// evaluation time each mark is &-ed with gid in sequence and combined per the
	// formula above. This also caps the grouping function at 64 parameters.
	marks := make([]map[uint64]struct{}, 0, len(sourceCols))
	if p.GroupingMode != tipb.GroupingMode_ModeBitAnd {
		// GroupingMode_ModeNumericSet: a column's mark is the set of grouping
		// IDs it participates in (an id set rather than a bit mask). E.g. when
		// GROUPING(x,y,z) returns 6, GROUPING(x)=1, GROUPING(y)=1, GROUPING(z)=0;
		// all three single-column marks are returned as function meta.
		for _, srcCol := range sourceCols {
			marks = append(marks, p.RollupID2GIDS[int(srcCol.UniqueID)])
		}
		return marks
	}
	// GroupingMode_ModeBitAnd: the mark of a column is a single-bit mask whose
	// bit position equals the column's index inside DistinctGroupByCol.
	// eg: distinctGBY [x,y,z] gives GROUPING(x) the mark '001' (bit 0 set);
	// whenever groupingID & mark > 0, column x is present in that grouping set
	// and is not grouped away.
	for _, srcCol := range sourceCols {
		var mask uint64
		for pos, gbyCol := range p.DistinctGroupByCol {
			if gbyCol.UniqueID == srcCol.UniqueID {
				mask |= uint64(1) << uint(pos)
			}
		}
		marks = append(marks, map[uint64]struct{}{mask: {}})
	}
	return marks
}

// TrySubstituteExprWithGroupingSetCol is used to substitute the original gby expression with new gby col.
func (p *LogicalExpand) TrySubstituteExprWithGroupingSetCol(expr expression.Expression) (expression.Expression, bool) {
	// Every original group-by item (even a bare column) has already been
	// projected below Expand, so match the expression against the recorded
	// gby expressions and map a hit to its new grouping-set column.
	targetHash := expr.CanonicalHashCode()
	for idx, gbyExpr := range p.DistinctGbyExprs {
		if bytes.Equal(targetHash, gbyExpr.CanonicalHashCode()) {
			// Found: substitute with the projected grouping-set column.
			return p.DistinctGroupByCol[idx], true
		}
	}
	// No match: hand back the expression untouched.
	return expr, false
}

// ResolveGroupingFuncArgsInGroupBy checks whether grouping function args is in grouping items.
func (p *LogicalExpand) ResolveGroupingFuncArgsInGroupBy(groupingFuncArgs []expression.Expression) ([]*expression.Column, error) {
	// Index the distinct group-by columns by unique ID for quick membership tests.
	gbyColIDs := make(map[int64]struct{}, len(p.DistinctGroupByCol))
	for _, gbyCol := range p.DistinctGroupByCol {
		gbyColIDs[gbyCol.UniqueID] = struct{}{}
	}
	resolved := make([]*expression.Column, 0, len(groupingFuncArgs))
	for argIdx, arg := range groupingFuncArgs {
		// First try to match the arg against the original group-by expressions
		// (all of them have been projected below Expand, even single cols) and
		// map a hit to the corresponding new grouping-set column.
		matched := -1
		for i, gbyExpr := range p.DistinctGbyExprs {
			if bytes.Equal(arg.CanonicalHashCode(), gbyExpr.CanonicalHashCode()) {
				matched = i
				break
			}
		}
		if matched >= 0 {
			// directly ref original group by expressions.
			resolved = append(resolved, p.DistinctGroupByCol[matched])
			continue
		}
		// No original-expression match. For cases like:
		//   select year from t group by year, country with rollup order by grouping(year)
		// the arg of grouping(year) has already been substituted with the
		// grouping-set column year' (via the first select item at pos 0) rather
		// than the original year — accept it if it is one of the gby columns.
		if argCol, isCol := arg.(*expression.Column); isCol {
			if _, known := gbyColIDs[argCol.UniqueID]; known {
				resolved = append(resolved, argCol)
				continue
			}
		}
		return nil, plannererrors.ErrFieldInGroupingNotGroupBy.GenWithStackByArgs(fmt.Sprintf("#%d", argIdx))
	}
	return resolved, nil
}

// GenerateGroupingIDModeBitAnd is used to generate convenient groupingID for quick computation of grouping function.
// A bit in the bitmask is corresponding to an attribute in the group by attributes sequence, the selected attribute
// has corresponding bit set to 0 and otherwise set to 1. Example, if we have GroupBy attributes(a,b,c,d), the bitmask
// 5 (whose binary form is 0101) represents grouping set (a,c).
func (p *LogicalExpand) GenerateGroupingIDModeBitAnd(oneSet expression.GroupingSet) uint64 {
	// say distinctGbyCols       :  a,     b,     c
	//       bit pos index       :  0,     1,     2
	// current grouping set is   :  {a, c}
	//                               +---- mark the corresponding pos as 1 then get --->     101
	//     for special case      :  {a,a,c} and {a,c}: this two logical same grouping set naturally share the same gid bits: 101
	neededIDs := oneSet.AllColIDs()
	var gid uint64
	// The bit position of each needed column equals its index inside
	// DistinctGroupByCol, so OR in 1<<pos for every member of the set.
	for pos, gbyCol := range p.DistinctGroupByCol {
		if neededIDs.Has(int(gbyCol.UniqueID)) {
			gid |= uint64(1) << uint(pos)
		}
	}
	// how to use it, eg: when encountering a grouping function like: grouping(a), we can know the column a's pos index in distinctGbyCols
	// is about 0, then we can get the mask as 001 which will be returned back as this grouping function's meta when rewriting it, then we
	// can use the bit mask to BitAnd(OP) groupingID column when evaluating, when the result is not 0, then for this row, it's column 'a'
	// is not grouped, marking them as 0, otherwise marking them as 1.
	return gid
}

// GenerateGroupingIDIncrementModeNumericSet is used to generate grouping ids when the num of grouping sets is greater than 64.
// Under this circumstance, bitAnd uint64 doesn't have enough capacity to set those bits, so incremental grouping ID set is chosen.
func (p *LogicalExpand) GenerateGroupingIDIncrementModeNumericSet(oneSetOffset int) uint64 {
	// say distinctGbyCols       :  a,     b,     c
	// say grouping sets         : {a,b,c}, {a,b},  {a},   {}    <----+  (store the mapping as grouping sets meta)
	// we can just set its gid   :  0,       1       2      3    <----+
	// just keep this mapping logic stored as meta, and return the defined id back generated from this defined rule.
	//     for special case      :  {a,a,c} and {a,c}: this two logical same grouping set naturally share the same gid allocation!
	//
	// how to use it, eg: when encountering a grouping function like: grouping(a), we should dig down to related Expand operator and
	// found it in meta that: column 'a' is in grouping set {a,b,c}, {a,b},  {a}, and its correspondent mapping grouping ids is about
	// {0,1,2}. This grouping id set is returned back as this grouping function's specified meta when rewriting the grouping function,
	// and the evaluating logic is quite simple as IN compare.
	//
	// RollupGroupingIDs is filled by RollupGroupingSets.DistinctSize() in
	// GenLevelProjections; duplicate logical sets share one allocated id.
	return p.RollupGroupingIDs[oneSetOffset]
}
